{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f1029e76-7e61-4296-801d-3a49637b5972",
   "metadata": {},
   "source": [
    "# Statistical Evaluation\n",
    "\n",
    "Evaluating model output with NLP statistical methods.\n",
    "\n",
    "Add `prettytable` to make a nice looking output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f96a0a-bb6e-4dbf-a37a-1f602429fb47",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install prettytable==3.10.2 sacrebleu==2.4.2 rouge-score==0.1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63e0f11c-210e-4540-b31b-c9b48979f46c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 show prettytable sacrebleu rouge-score "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3650ce4e-c63a-4f96-968f-bffbfd030d3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from prettytable import PrettyTable\n",
    "import sacrebleu\n",
    "from rouge_score import rouge_scorer\n",
    "\n",
    "evaluation_data = [\n",
    "    {\n",
    "        \"input\": \"What should I do in New York City in July?\",\n",
    "        \"output\": \"Check out Times Square, go to an outdoor concert, and visit the Statue of Liberty.\",\n",
    "        \"golden_answer\": \"Explore Central Park, attend outdoor concerts, and visit rooftop bars.\",\n",
    "        \"contexts\": [\n",
    "            \"Times Square is known for its Broadway theaters, bright lights, and bustling atmosphere.\",\n",
    "            \"Outdoor concerts in Central Park are popular summer events attracting many visitors.\",\n",
    "            \"The Statue of Liberty is a symbol of freedom and a must-see landmark in NYC.\"\n",
    "        ]\n",
    "    },\n",
    "    {\n",
    "        \"input\": \"Can you help me with my math homework?\",\n",
    "        \"output\": \"I'm designed to assist with travel queries. For math help, try using online resources like Khan Academy or Mathway.\",\n",
    "        \"golden_answer\": \"I am a travel assistant chatbot, so I cannot help you with your math homework.\",\n",
    "        \"contexts\": []\n",
    "    },\n",
    "    {\n",
    "        \"input\": \"What’s the capital of France?\",\n",
    "        \"output\": \"The capital of France is Paris.\",\n",
    "        \"golden_answer\": \"Paris is the capital of France.\",\n",
    "        \"contexts\": [\n",
    "            \"Paris, known as the City of Light, is the most populous city of France.\",\n",
    "            \"European capitals: Paris, France; Berlin, Germany; Madrid, Spain\",\n",
    "        ]\n",
    "    }\n",
    "]\n",
    "\n",
    "# Statistical evaluators\n",
    "def evaluate_bleu(output, golden_answer):\n",
    "    bleu = sacrebleu.corpus_bleu([output], [[golden_answer]])\n",
    "    return bleu.score / 100  # Normalize BLEU score to be between 0 and 1\n",
    "\n",
    "def evaluate_rouge(output, contexts):\n",
    "    context_text = (\"\\n\").join(contexts)\n",
    "    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)\n",
    "    scores = scorer.score(context_text, output)\n",
    "    return scores['rougeL'].fmeasure\n",
    "\n",
    "\n",
    "def calculate_average(scores):\n",
    "    return sum(scores) / len(scores)\n",
    "\n",
    "# truncate strings for easier printing in table\n",
    "def truncate_string(s, max_length=10):\n",
    "    return (s[:max_length] + '...') if len(s) > max_length else s\n",
    "\n",
    "def create_table(data):\n",
    "    table = PrettyTable()\n",
    "    table.field_names = [\"Input\", \"Output\", \"Golden Answer\", \"# Contexts\", \"BLEU\", \"ROGUE\"]\n",
    "    \n",
    "    bleu_scores = [evaluate_bleu(case[\"output\"], case[\"golden_answer\"]) for case in data]\n",
    "    rouge_scores = [evaluate_rouge(case[\"output\"], case[\"contexts\"]) for case in data]\n",
    "    \n",
    "    for case, bleu, rouge in zip(data, bleu_scores, rouge_scores):\n",
    "        table.add_row([\n",
    "            truncate_string(case[\"input\"]),\n",
    "            truncate_string(case[\"output\"]),\n",
    "            truncate_string(case[\"golden_answer\"]),\n",
    "            len(case[\"contexts\"]),\n",
    "            f\"{bleu:.4f}\",\n",
    "            f\"{rouge:.4f}\"])\n",
    "    \n",
    "    # Add a blank row for visual separation\n",
    "    table.add_row([\"\", \"\", \"\", \"\", \"\", \"\"])\n",
    "    \n",
    "    # Add the average score to bottom of the table\n",
    "    average_bleu = calculate_average(bleu_scores)\n",
    "    average_rouge = calculate_average(rouge_scores)\n",
    "    \n",
    "    table.add_row([\"Average\", \"\", \"\", \"\", f\"{average_bleu:.4f}\", f\"{average_rouge:.4f}\"])\n",
    "    \n",
    "    return table\n",
    "\n",
    "# Create and print the table\n",
    "result_table = create_table(evaluation_data)\n",
    "print(result_table)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
